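"""
处理计数 -> 二值化
Handling count data -> binarization: plot the histogram of listen counts from the
train triplets file before and after binarizing the counts.
"""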

from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

def hist_plot(data, col, xlabel, ylabel, save_name, *, bins=100, log_y=True):
    """Plot a histogram of ``data[col]``, save it to ``save_name`` and show it.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding the column to plot (the caller passes a ``read_csv`` result).
    col : str
        Name of the column whose values are binned.
    xlabel, ylabel : str
        Axis labels, drawn at font size 14.
    save_name : str or os.PathLike
        Output image path. Missing parent directories are created.
    bins : int, optional
        Number of histogram bins (default 100, as before).
    log_y : bool, optional
        Draw the y axis on a log scale (default True, as before). Useful for
        heavy-tailed counts where a few bins dominate.
    """
    sns.set_style('whitegrid')
    fig, ax = plt.subplots()
    data[col].hist(ax=ax, bins=bins)
    if log_y:
        ax.set_yscale('log')
    ax.tick_params(labelsize=14)
    ax.set_xlabel(xlabel, fontsize=14)
    ax.set_ylabel(ylabel, fontsize=14)
    # Create the output directory so saving does not fail when it is absent.
    Path(save_name).parent.mkdir(parents=True, exist_ok=True)
    # Save through the figure we built, not pyplot's implicit "current figure".
    fig.savefig(save_name)
    plt.show()

# Load the dataset: tab-separated (user, song, listen count) triplets with no header row.
triplets = pd.read_csv('../数据集/train_triplets.txt', header=None, names=['userID', 'songID', 'listenCount'],
                       delimiter='\t')
# Histogram of the raw listen counts.
hist_plot(triplets, 'listenCount', 'Listen Count', 'Occurrence', './可视化/收听次数直方图.png')

# Binarize the listen counts: a positive count becomes 1 ("listened"), anything else 0.
# Thresholding (instead of assigning the constant 1 to every row) stays correct
# if the file ever contains zero or missing counts.
triplets['listenCount'] = (triplets['listenCount'] > 0).astype(int)
# Histogram after binarization.
hist_plot(triplets, 'listenCount', 'Listen Count', 'Occurrence', './可视化/收听次数直方图_二值化后.png')